## Regression models of placement outcomes
library(tidyverse)
## ── Attaching packages ────────────────────────────────────────────── tidyverse 1.2.1 ──
## ✔ ggplot2 3.1.0     ✔ purrr   0.2.5
## ✔ tibble  1.4.2     ✔ dplyr   0.7.8
## ✔ tidyr   0.8.2     ✔ stringr 1.3.1
## ✔ readr   1.1.1     ✔ forcats 0.3.0
## ── Conflicts ───────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(broom)
library(forcats)
library(rstanarm)
## Loading required package: Rcpp
## rstanarm (Version 2.18.2, packaged: 2018-11-08 22:19:38 UTC)
## - Do not expect the default priors to remain the same in future rstanarm versions.
## Thus, R scripts should specify priors explicitly, even if they are just the defaults.
## - For execution on a local, multicore CPU with excess RAM we recommend calling
## options(mc.cores = parallel::detectCores())
## - Plotting theme set to bayesplot::theme_default().
## Leave two cores free for other work, but always ask for at least one:
## detectCores() - 2 is 0 or negative on machines with <= 2 cores, and
## detectCores() can return NA when it cannot determine the count
options(mc.cores = max(1L, parallel::detectCores() - 2L, na.rm = TRUE))
## Loading rstanarm sets bayesplot::theme_default() as the ggplot2 theme;
## switch back to a minimal theme
theme_set(theme_minimal())

library(tictoc)
library(assertthat)
## 
## Attaching package: 'assertthat'
## The following object is masked from 'package:tibble':
## 
##     has_name
## Keep messages out of the knitted HTML output
knitr::opts_chunk$set(list(message = FALSE))

## broom.mixed reportedly has an augment() method for rstanarm (stanreg
## objects), but I can't find its documentation to configure its options,
## so this is a small hand-rolled replacement.
##
## predictions(): posterior summaries of the fitted probabilities, one row
## per observation, bound to the data the model was fit on.
##   model: fitted rstanarm model with a logit link (e.g. binomial stan_glmer)
##   lower, upper: quantiles of the credible interval; the defaults give a
##     central 89% interval (.055 to .945)
## Returns a data frame with .median, .lower, .upper (probability scale),
## .residual (model$residuals), then the columns of model$data.
predictions = function(model, lower = .055, upper = 1 - lower) {
    ## Posterior draws of the linear predictor (log-odds); columns are
    ## observations, so summaries are taken over columns
    predictions_mx = posterior_linpred(model)
    ## model$data is the data frame handed to the fitting function. If the fit
    ## dropped rows (e.g. NAs), cbind() below would fail or, worse, silently
    ## recycle rows, so check the alignment up front
    assert_that(ncol(predictions_mx) == nrow(model$data), 
                msg = 'model$data rows do not match the fitted observations')
    
    ## Quantiles commute with the monotone inverse-logit, so summarise on the
    ## log-odds scale and back-transform the summaries
    dataf = tibble(.median = apply(predictions_mx, 2, median), 
                   .lower = apply(predictions_mx, 2, quantile, probs = lower), 
                   .upper = apply(predictions_mx, 2, quantile, probs = upper))
    ## stats::plogis is the inverse logit; avoids an undeclared gtools dependency
    dataf = mutate_all(dataf, plogis)
    dataf$.residual = model$residuals
    dataf = cbind(dataf, model$data)
    return(dataf)
}

## Extract effect estimates, with nice orderings on variables for plotting.
##
## posterior_estimates(): combine the fixed ("non-varying") and varying
## effects of a fitted rstanarm model into one tidy table.
##   model: fitted model with a logit link (e.g. binomial stan_glmer)
## Returns the tidy() columns plus:
##   entity  'intercept', 'program' or 'individual' (asserted to be non-NA)
##   group   covariate group; a factor in x-axis order
##   level   display label; a factor in x-axis order, with the year random
##           effects (graduation_year, placement_year) listed most recent first
## estimate, lower and upper are exponentiated (log-odds -> odds ratios).
## std.error is left on the log-odds scale: exp(std.error) is not a standard
## error on the odds-ratio scale.
posterior_estimates = function(model) {
    fixed = tidy(model, parameters = 'non-varying', intervals = TRUE)
    varying = tidy(model, parameters = 'varying', intervals = TRUE)
    
    ## NOTE(review): bind_rows() warnings (presumably about coercing columns
    ## whose types/levels differ between the two tables) are suppressed
    combined = suppressWarnings(bind_rows(fixed, varying))
    
    combined = combined %>% 
        ## Program- vs individual-level variables
        mutate(entity = case_when(term == '(Intercept)' & is.na(group) ~ 'intercept', 
                                  str_detect(term, 'prestige') ~ 'program',
                                  str_detect(term, 'gender') ~ 'individual', 
                                  str_detect(term, 'country') ~ 'program', 
                                  term == 'perc_w' ~ 'program', 
                                  term == 'total_placements' ~ 'program',
                                  group == 'community' ~ 'program',
                                  group == 'graduation_year' ~ 'individual',
                                  group == 'placement_year' ~ 'individual',
                                  group == 'aos_category' ~ 'individual', 
                                  term == 'aos_diversity' ~ 'program',
                                  term == 'log10(in_centrality)' ~ 'program',
                                  TRUE ~ NA_character_)) %>% 
        ## Variable groups
        mutate(group = case_when(!is.na(group) ~ as.character(group), 
                                 term == '(Intercept)' ~ 'intercept', 
                                 str_detect(term, 'prestige') ~ 'prestige', 
                                 str_detect(term, 'gender') ~ 'gender', 
                                 str_detect(term, 'country') ~ 'country', 
                                 term == 'perc_w' ~ 'continuous', 
                                 term == 'total_placements' ~ 'continuous', 
                                 term == 'aos_diversity' ~ 'continuous',
                                 term == 'log10(in_centrality)' ~ 'continuous',
                                 TRUE ~ NA_character_)) %>% 
        ## Nicer labels
        mutate(level = ifelse(!is.na(level), as.character(level), term), 
               level = str_replace_all(level, '_', ' '), 
               level = str_remove(level, '^prestige'), 
               level = case_when(level == 'genderw' ~ 'gender: woman', 
                                 level == 'gendero' ~ 'gender: other', 
                                 TRUE ~ level),
               level = str_remove(level, 'country'), 
               level = str_replace(level, 'perc w', 'women (%)')) %>% 
        ## Arrange x-axis (exp() is monotone, so ordering on the log-odds
        ## estimate gives the same order as on the odds-ratio scale)
        arrange(entity, group, estimate) %>% 
        mutate(group = fct_inorder(group), 
               level = fct_inorder(level))
    
    ## Fix year ordering: move the year levels to the front, most recent
    ## first. The years are read off the year random effects instead of being
    ## hard-coded, so other years are handled too and fct_relevel() is never
    ## given levels that don't exist (which makes it warn).
    year_levels = combined %>% 
        filter(group %in% c('graduation_year', 'placement_year')) %>% 
        pull(level) %>% 
        as.character() %>% 
        unique() %>% 
        sort(decreasing = TRUE)
    
    combined = combined %>% 
        mutate(level = fct_relevel(level, year_levels)) %>% 
        ## Backtransform point estimates and interval bounds to odds ratios
        mutate_at(vars(estimate, lower, upper), exp)
    
    assert_that(!any(is.na(combined$entity)), msg = 'NA values for entity')
    return(combined)
}
data_folder = '../data/'

## individual_df comes from the parsed .Rdata file; univ_df holds the
## per-university network statistics joined onto it below
load(str_c(data_folder, '01_parsed.Rdata'))
univ_df = read_rds(str_c(data_folder, '02_univ_net_stats.rds'))

## Variables used by the model below, plus those needed for the derived
## percentages. Dropping NAs on all of them keeps the descriptive statistics
## on exactly the rows the model is fit to. Otherwise stan_glmer() would
## silently drop incomplete rows, and model$data would no longer line up
## with the fitted values.
analysis_vars = c('permanent', 'aos_category', 'aos_diversity', 
                  'graduation_year', 'placement_year', 'prestige', 
                  'community', 'gender', 'country', 'in_centrality', 
                  'frac_w', 'frac_high_prestige', 'total_placements')

individual_df = individual_df %>%
    left_join(univ_df, by = c('placing_univ_id' = 'univ_id')) %>%
    ## Use the canonical names from univ_df
    select(-placing_univ) %>%
    ## Drop rows with NAs in any analysis variable
    filter_at(analysis_vars, all_vars(!is.na(.))) %>%
    mutate(perc_w = 100*frac_w, 
           perc_high_prestige = 100*frac_high_prestige)

## Variables to consider: aos_category; graduation_year; placement_year; prestige; out_centrality; cluster; community; placing_univ_id; gender; country; perc_w; total_placements

## Giant pairs plot/correlogram
## perc_high_prestige, out_centrality, and prestige are all tightly correlated
## All other pairs have low to moderate correlation
individual_df %>% 
    select(permanent, aos_category, aos_diversity, perc_high_prestige,
           graduation_year, placement_year, prestige, 
           in_centrality, out_centrality, community, gender, country, perc_w, 
           total_placements) %>% 
    ## cor() needs numbers: non-numeric variables become integer codes, so
    ## correlations involving unordered categories are only rough indicators
    mutate_if(negate(is.numeric), function(x) as.integer(as.factor(x))) %>% 
    mutate_at(vars(in_centrality, out_centrality), log10) %>% 
    # GGally::ggpairs()
    ## Pairwise-complete observations so a stray NA doesn't blank out a
    ## whole row/column of the matrix
    cor(use = 'pairwise.complete.obs') %>% 
    as_tibble(rownames = 'Var1') %>% 
    gather(key = 'Var2', value = 'cor', -Var1) %>% 
    ggplot(aes(Var1, Var2, fill = cor)) +
    geom_tile() +
    ## Black labels: scale_fill_gradient2() is white at 0, where white text
    ## would be invisible
    geom_text(aes(label = round(cor, digits = 2)), 
              color = 'black') +
    ## Fixed limits so colour intensity means the same thing in every tile
    scale_fill_gradient2(limits = c(-1, 1)) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

## Outcome (permanent, coded 0/1) against two program-level covariates,
## with a loess smoother

## No indication that AOS diversity has any effect
individual_df %>% 
    ggplot(aes(x = aos_diversity, y = 1 * permanent)) + 
    geom_point() +
    geom_smooth(method = 'loess')

## Nor for the fraction of PhDs awarded to women
individual_df %>% 
    ggplot(aes(x = frac_w, y = 1 * permanent)) +
    geom_point() +
    geom_smooth(method = 'loess')

## Descriptive statistics ----
## Counts for each value of the individual-level categorical variables.
## The selected columns have different types, so gather() would coerce them
## itself and warn that attributes differ; convert to character first.
individual_df %>%
    select(permanent, aos_category, 
           graduation_year, gender) %>%
    mutate_all(as.character) %>% 
    gather(key = variable, value = value) %>%
    count(variable, value)
## # A tibble: 14 x 3
##    variable        value                        n
##    <chr>           <chr>                    <int>
##  1 aos_category    History and Traditions     525
##  2 aos_category    LEMM                       600
##  3 aos_category    Science, Logic, and Math   310
##  4 aos_category    Value Theory               696
##  5 gender          m                         1521
##  6 gender          o                            1
##  7 gender          w                          609
##  8 graduation_year 2012                       451
##  9 graduation_year 2013                       435
## 10 graduation_year 2014                       443
## 11 graduation_year 2015                       413
## 12 graduation_year 2016                       389
## 13 permanent       FALSE                      980
## 14 permanent       TRUE                      1151
## Counts for each value of the program-level categorical variables
gather(select(individual_df, prestige, country), 
       key = variable, value = value) %>% 
    count(variable, value)
## # A tibble: 16 x 3
##    variable value             n
##    <chr>    <chr>         <int>
##  1 country  Australia        65
##  2 country  Belgium          55
##  3 country  Canada          153
##  4 country  France            9
##  5 country  Germany           3
##  6 country  Greece            1
##  7 country  Hungary           8
##  8 country  Ireland           2
##  9 country  Netherlands       7
## 10 country  New Zealand       7
## 11 country  Norway            1
## 12 country  Sweden            1
## 13 country  U.K.            298
## 14 country  U.S.           1521
## 15 prestige high-prestige  1058
## 16 prestige low-prestige   1073
## Min / max / mean / median / sd of the continuous program-level variables
individual_df %>%
    select(frac_w, total_placements, perm_placement_rate) %>%
    gather(key = variable, value = value) %>%
    group_by(variable) %>%
    summarize(min = min(value, na.rm = TRUE), 
              max = max(value, na.rm = TRUE), 
              mean = mean(value, na.rm = TRUE), 
              median = median(value, na.rm = TRUE), 
              sd = sd(value, na.rm = TRUE))
## # A tibble: 3 x 6
##   variable              min   max   mean median     sd
##   <chr>               <dbl> <dbl>  <dbl>  <dbl>  <dbl>
## 1 frac_w                  0     1  0.280  0.265  0.147
## 2 perm_placement_rate     0     1  0.527  0.536  0.179
## 3 total_placements        1    73 23.3   21     14.4
## Model -----
## Fit the multilevel logistic regression of permanent placement, or load a
## previously cached fit.
## NOTE: the fit is cached in model_file and reused whenever that file
## exists. Delete the file after changing the data, formula, priors, or
## sampler settings; otherwise the stale fit is silently reused.
model_file = str_c(data_folder, '03_model.Rds')
if (!file.exists(model_file)) {
    ## ~400 seconds
    tic()
    model = individual_df %>% 
        ## Make low-prestige and U.S. the first (reference) levels
        mutate(prestige = fct_relevel(prestige, 'low-prestige'), 
               country = fct_relevel(country, 'U.S.')) %>% 
        stan_glmer(formula = permanent ~ 
                       (1|aos_category) +
                       gender + 
                       (1|graduation_year) +
                       (1|placement_year) +
                       1 +
                       aos_diversity +
                       (1|community) +
                       log10(in_centrality) +
                       total_placements +
                       perc_w +
                       country +
                       prestige,
                   family = 'binomial',
                   ## Priors
                   ## Constant and coefficients
                   prior_intercept = normal(0, 1), 
                   prior = normal(0, .5),
                   ## Auxiliary-parameter prior. NOTE(review): a binomial
                   ## model has no auxiliary (error sd) parameter, so this is
                   ## presumably ignored by rstanarm; kept for explicitness
                   prior_aux = exponential(rate = 1, 
                                           autoscale = TRUE),
                   ## random effects covariance
                   prior_covariance =  decov(regularization = 1, 
                                             concentration = 1, 
                                             shape = 1, scale = 1),
                   ## Fixed seed so the sampler's draws are reproducible
                   seed = 1159518215,
                   ## NOTE(review): high adapt_delta presumably to avoid
                   ## divergent transitions; confirm
                   adapt_delta = .99,
                   chains = 2, iter = 4000)
    toc()
    write_rds(model, model_file)
} else {
    model = read_rds(model_file)
}
## Check ESS and Rhat
## Rhats all look good.  ESS a little low for grad years + some sigmas
## Reference lines at n_eff = 2000 and Rhat = 1.10; points above the
## horizontal line would signal convergence problems
ess_rhat_plot = model %>%
    summary() %>%
    as.data.frame() %>%
    rownames_to_column('parameter') %>%
    select(parameter, n_eff, Rhat) %>%
    # knitr::kable()
    ggplot(aes(n_eff, Rhat, label = parameter)) +
    geom_point() +
    geom_vline(xintercept = 2000) +
    geom_hline(yintercept = 1.10)
ess_rhat_plot

## Interactive version, so each point's `parameter` shows up on hover.
## requireNamespace() checks for plotly without attaching it; attaching
## would mask several functions (e.g. ggplot2::last_plot). The plot is
## passed explicitly rather than relying on last_plot().
if (requireNamespace('plotly', quietly = TRUE)) {
    plotly::ggplotly(ess_rhat_plot)
}
## Check predictions: posterior predictive checks with 200 replicated
## datasets each. A fixed seed makes the replications, and so the knitted
## figures, reproducible; the standing and hanging rootograms then also
## show the same draws.
pp_check(model, nreps = 200, seed = 1159518215)

pp_check(model, nreps = 200, seed = 1159518215, plotfun = 'ppc_bars')

## Rootograms: <https://arxiv.org/pdf/1605.01311.pdf>
pp_check(model, nreps = 200, seed = 1159518215, plotfun = 'ppc_rootogram')

pp_check(model, nreps = 200, seed = 1159518215, plotfun = 'ppc_rootogram', 
         style = 'hanging')

## Posterior summaries of the fitted probabilities, with the model data
pred = predictions(model)

## Residuals against the median fitted probability
ggplot(pred, aes(.median, .residual)) +
    geom_point(aes(color = permanent)) +
    geom_smooth()

## Credible intervals for the residuals. A residual is y - p, so the
## interval [.lower, .upper] for p maps to [y - .upper, y - .lower].
## (Subtracting the probability bounds from the residual itself, as
## .residual - .lower, double-counts the fitted value and does not bracket
## the residual.)
ggplot(pred, aes(.median, .residual, color = permanent)) +
    geom_linerange(aes(ymin = 1*permanent - .upper, 
                       ymax = 1*permanent - .lower))

## Fitted probabilities by observed outcome, with lines at the 5%, 50% and
## 95% quantiles
ggplot(pred, aes(permanent, .median)) +
    geom_violin(draw_quantiles = c(.05, .5, .95))

## Odds ratios faceted by entity (program vs individual), leaving out the
## intercept and the placement-year effects; dashed line = no effect (OR 1)
ggplot(filter(posterior_estimates(model), 
              entity != 'intercept', 
              group != 'placement_year'), 
       aes(x = level, y = estimate, ymin = lower, ymax = upper, 
           color = group)) +
    geom_hline(yintercept = 1, linetype = 'dashed') +
    geom_pointrange() + 
    scale_color_viridis_d(name = 'covariate\ngroup') +
    labs(x = '', y = '') +
    coord_flip() +
    facet_wrap(~ entity, scales = 'free')

## Every estimate, one facet per covariate group, coloured by entity;
## dashed line = no effect (odds ratio 1)
posterior_estimates(model) %>% 
    ggplot(aes(x = level, y = estimate, 
               ymin = lower, ymax = upper, 
               color = entity)) +
    geom_hline(yintercept = 1, linetype = 'dashed') +
    geom_pointrange() +
    coord_flip() +
    facet_wrap(~ group, scales = 'free_y')

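## Playing around with ridgeline plots (disabled)
## NOTE: the commented-out code below refers to `model_5`, not the `model`
## fitted above, and to random-effect terms such as
## `b[(Intercept) prestige:high-prestige]` and `b[(Intercept) gender:m]`.
## In the current model prestige and gender are fixed effects, so this code
## needs updating before it is re-enabled.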
# library(ggridges)
# model %>%
#     as_tibble() %>%
#     gather(key = parameter, value = value) %>%
#     filter(str_detect(parameter, 'aos_category') & 
#                !str_detect(parameter, 'Sigma')) %>%
#     mutate(parameter = str_remove(parameter, 'b\\[.*:'), 
#            parameter = str_remove(parameter, '\\]'), 
#            parameter = str_replace_all(parameter, '_', ' ')) %>%
#     ggplot(aes(value, y = parameter, group = parameter)) + 
#     geom_density_ridges(scale = 1.5, 
#                         rel_min_height = .01)
# 
# 
# ## Posterior for diff between high- and low-prestige
# model_5 %>%
#     as_tibble() %>%
#     select(high_prestige = `b[(Intercept) prestige:high-prestige]`, 
#            low_prestige = `b[(Intercept) prestige:low-prestige]`) %>%
#     mutate(diff = high_prestige - low_prestige) %>%
#     mutate_all(exp) %>%
#     ggplot(aes(diff)) + 
#     geom_density() + 
#     geom_vline(aes(xintercept = median(diff))) +
#     scale_x_continuous(#trans = scales::exp_trans(), 
#         # breaks = scales::log_breaks(),
#         # labels = scales::number_format(accuracy = .1), 
#         name = 'Odds Ratio')
# ## And between genders
# model_5 %>%
#     as_tibble() %>%
#     select(m = `b[(Intercept) gender:m]`, 
#            w = `b[(Intercept) gender:w]`) %>%
#     mutate(diff = w - m) %>%
#     mutate_all(exp) %>%
#     ggplot(aes(diff)) + 
#     geom_density() + 
#     geom_vline(aes(xintercept = median(diff))) + 
#     scale_x_continuous(name = 'Odds Ratio')
## Record the R and package versions used to produce these results
utils::sessionInfo()
## R version 3.5.1 (2018-07-02)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS  10.14.2
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] plotly_4.8.0     bindrcpp_0.2.2   assertthat_0.2.0 tictoc_1.0      
##  [5] rstanarm_2.18.2  Rcpp_1.0.0       broom_0.5.0      forcats_0.3.0   
##  [9] stringr_1.3.1    dplyr_0.7.8      purrr_0.2.5      readr_1.1.1     
## [13] tidyr_0.8.2      tibble_1.4.2     ggplot2_3.1.0    tidyverse_1.2.1 
## 
## loaded via a namespace (and not attached):
##  [1] nlme_3.1-137       matrixStats_0.54.0 xts_0.11-2        
##  [4] lubridate_1.7.4    threejs_0.3.1      httr_1.3.1        
##  [7] rprojroot_1.3-2    rstan_2.18.2       tools_3.5.1       
## [10] backports_1.1.2    utf8_1.1.4         R6_2.3.0          
## [13] DT_0.5             mgcv_1.8-25        lazyeval_0.2.1    
## [16] colorspace_1.3-2   withr_2.1.2        tidyselect_0.2.5  
## [19] gridExtra_2.3      prettyunits_1.0.2  processx_3.2.0    
## [22] compiler_3.5.1     cli_1.0.1          rvest_0.3.2       
## [25] xml2_1.2.0         shinyjs_1.0        labeling_0.3      
## [28] colourpicker_1.0   scales_1.0.0       dygraphs_1.1.1.6  
## [31] ggridges_0.5.1     callr_3.0.0        digest_0.6.18     
## [34] StanHeaders_2.18.0 minqa_1.2.4        rmarkdown_1.10    
## [37] base64enc_0.1-3    pkgconfig_2.0.2    htmltools_0.3.6   
## [40] lme4_1.1-19        htmlwidgets_1.3    rlang_0.3.0.1     
## [43] readxl_1.1.0       rstudioapi_0.8     shiny_1.2.0       
## [46] bindr_0.1.1        zoo_1.8-4          jsonlite_1.5      
## [49] crosstalk_1.0.0    gtools_3.8.1       inline_0.3.15     
## [52] magrittr_1.5       loo_2.0.0          bayesplot_1.6.0   
## [55] Matrix_1.2-15      fansi_0.4.0        munsell_0.5.0     
## [58] stringi_1.2.4      yaml_2.2.0         MASS_7.3-51.1     
## [61] pkgbuild_1.0.2     plyr_1.8.4         grid_3.5.1        
## [64] parallel_3.5.1     promises_1.0.1     crayon_1.3.4      
## [67] miniUI_0.1.1.1     lattice_0.20-38    splines_3.5.1     
## [70] haven_1.1.2        hms_0.4.2          knitr_1.20        
## [73] ps_1.2.1           pillar_1.3.0       igraph_1.2.2      
## [76] markdown_0.8       shinystan_2.5.0    codetools_0.2-15  
## [79] reshape2_1.4.3     stats4_3.5.1       rstantools_1.5.1  
## [82] glue_1.3.0         evaluate_0.12      data.table_1.11.8 
## [85] modelr_0.1.2       nloptr_1.2.1       httpuv_1.4.5      
## [88] cellranger_1.1.0   gtable_0.2.0       mime_0.6          
## [91] xtable_1.8-3       later_0.7.5        viridisLite_0.3.0 
## [94] survival_2.43-1    rsconnect_0.8.8    shinythemes_1.1.2